{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import matplotlib.pyplot as plt # plotting\n",
    "import numpy as np # linear algebra\n",
    "import os # accessing directory structure\n",
    "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
    "import csv\n",
    "import re\n",
    "\n",
    "import jieba\n",
    "from sklearn.feature_extraction.text import TfidfTransformer\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "import gensim\n",
    "from gensim.models import Word2Vec\n",
    "from sklearn.preprocessing import scale\n",
    "import multiprocessing\n",
    "\n",
    "from snownlp import SnowNLP\n",
    "import jieba.analyse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>class</th>\n",
       "      <th>positive</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>index</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>﻿18年结婚 哈哈哈</td>\n",
       "      <td>0</td>\n",
       "      <td>0.900696</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2017最后顿大餐吃完两人世界明年就是三个人一起啦许下生日愿望️希望一家人都能顺利平安健康🏻🏻🏻</td>\n",
       "      <td>1</td>\n",
       "      <td>0.999904</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>意盎然的季节！祝愿大家都生机勃勃，郁郁葱葱！</td>\n",
       "      <td>2</td>\n",
       "      <td>0.736431</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2017 遇见挚友 遇见我老公 结了婚有了小芒果     希望2018也超级美好️</td>\n",
       "      <td>3</td>\n",
       "      <td>0.983905</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2018.1.1</td>\n",
       "      <td>4</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>2018加油！</td>\n",
       "      <td>5</td>\n",
       "      <td>0.895319</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>2018年做一个更加真实的自己。️</td>\n",
       "      <td>3</td>\n",
       "      <td>0.783433</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>2018年的第一天，完美的错过了一辆公交车。 德州</td>\n",
       "      <td>6</td>\n",
       "      <td>0.934181</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>2018年目标1.赚钱买房2.谈场恋爱，遇到对的人就结婚3.拥有一副健康的身体4.学会一种乐...</td>\n",
       "      <td>7</td>\n",
       "      <td>0.999799</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>2018年第一个假期：元旦，就这么过去了，感冒咳嗽发高烧给这个元旦带来了不一样的节日，好快呀...</td>\n",
       "      <td>8</td>\n",
       "      <td>0.733896</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    text  class  positive\n",
       "index                                                                    \n",
       "0                                             ﻿18年结婚 哈哈哈      0  0.900696\n",
       "1       2017最后顿大餐吃完两人世界明年就是三个人一起啦许下生日愿望️希望一家人都能顺利平安健康🏻🏻🏻      1  0.999904\n",
       "2                                 意盎然的季节！祝愿大家都生机勃勃，郁郁葱葱！      2  0.736431\n",
       "3              2017 遇见挚友 遇见我老公 结了婚有了小芒果     希望2018也超级美好️      3  0.983905\n",
       "4                                               2018.1.1      4  0.500000\n",
       "5                                                2018加油！      5  0.895319\n",
       "6                                      2018年做一个更加真实的自己。️      3  0.783433\n",
       "7                              2018年的第一天，完美的错过了一辆公交车。 德州      6  0.934181\n",
       "8      2018年目标1.赚钱买房2.谈场恋爱，遇到对的人就结婚3.拥有一副健康的身体4.学会一种乐...      7  0.999799\n",
       "9      2018年第一个假期：元旦，就这么过去了，感冒咳嗽发高烧给这个元旦带来了不一样的节日，好快呀...      8  0.733896"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the labelled training set; the first CSV column becomes the index.\n",
    "# Missing texts are replaced by '' so every row can be segmented as a str.\n",
    "train_csv_path = \"C:/Users/Kai/Desktop/171840708_IntroDM_MiningChallenge/mining-challenge-for-nju-introdm-2019/Mining Challenge Dataset/train.csv\"\n",
    "dff = pd.read_csv(train_csv_path, index_col=0)\n",
    "dff = dff.fillna({'text': ''})\n",
    "dff.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "[' ', '  ', '~', '。', '，', '…', '～', '!', '\"', '#', '$', '%', '&', \"'\", '(', ')', '*', '+', ',', '-', '--', '.', ':', '://', '::', ';', '<', '=', '>', '>>', '?', '@', 'A', 'Lex', '', '\\\\', '', '^', '_', '`', 'exp', 'sub', 'sup', '|', '}', '~', '~~~~', '·', '×', '×××', 'Δ', 'Ψ', 'γ', 'μ', 'φ', 'φ．', 'В', '—', '——', '———', '‘', '’', '’‘', '“', '”', '”，', '…', '……', '…………………………………………………③', '′∈', '′｜', '℃', 'Ⅲ', '↑', '→', '∈', '∪φ∈', '≈', '①', '②', '②ｃ', '③', '③', '④', '⑤', '⑥', '⑦', '⑧', '⑨', '⑩', '──', '■', '▲', '\\u3000', '、', '。', '〈', '〉', '《', '》']\n"
     ]
    }
   ],
   "source": [
    "def stopwordslist(path=\"C:/Users/Kai/Desktop/stop.txt\", encoding=None):\n",
    "    \"\"\"Read a stop-word file that has one entry per line.\n",
    "\n",
    "    Newlines and ASCII / full-width square brackets ('[', ']', '［', '］')\n",
    "    are removed from every line. Blank lines are kept as '' entries, exactly\n",
    "    as before.\n",
    "\n",
    "    Args:\n",
    "        path: stop-word file to read (defaults to the original location).\n",
    "        encoding: text encoding passed to open(); None keeps the platform\n",
    "            default that the original code relied on.\n",
    "\n",
    "    Returns:\n",
    "        list[str]: cleaned stop words in file order.\n",
    "    \"\"\"\n",
    "    stopwords = []\n",
    "    # `with` closes the handle even on error; the original never closed it.\n",
    "    with open(path, \"r\", encoding=encoding) as f:\n",
    "        for index, line in enumerate(f):\n",
    "            if index % 1000 == 0:\n",
    "                print(index)  # coarse progress output\n",
    "            for unwanted in ('\\n', '[', ']', '］', '［'):\n",
    "                line = line.replace(unwanted, '')\n",
    "            stopwords.append(line)\n",
    "\n",
    "    print(stopwords[:100])\n",
    "    return stopwords\n",
    "\n",
    "# Build the stop-word collection. It is stored as a set because seg_depart\n",
    "# checks `word not in stopwords` for every token of ~860k training texts:\n",
    "# membership in a list is a linear scan, in a set it is O(1) on average.\n",
    "stopwords = set(stopwordslist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chinese word segmentation with stop-word removal.\n",
    "def seg_depart(sentence):\n",
    "    \"\"\"Segment one sentence with jieba and drop unwanted tokens.\n",
    "\n",
    "    The sentence is stripped of surrounding whitespace before segmentation.\n",
    "    Tokens contained in the global `stopwords` collection and literal tab\n",
    "    tokens are discarded. Each kept token is followed by one space, so a\n",
    "    non-empty result ends with a trailing space and an empty one is ''.\n",
    "    \"\"\"\n",
    "    kept_tokens = [\n",
    "        token\n",
    "        for token in jieba.cut(sentence.strip())\n",
    "        if token not in stopwords and token != '\\t'\n",
    "    ]\n",
    "    return ''.join(token + ' ' for token in kept_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer class labels as an array (referenced by the commented-out\n",
    "# TF-IDF / naive Bayes experiments further down).\n",
    "clas = dff.loc[:, 'class'].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>class</th>\n",
       "      <th>positive</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>index</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>我是正面哦</td>\n",
       "      <td>0</td>\n",
       "      <td>0.347826</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>爱是恒久忍耐，又有恩慈。爱是不嫉妒，不自夸，不张狂，不轻易发怒。不计算人的恶。凡事包容。凡事...</td>\n",
       "      <td>0</td>\n",
       "      <td>0.496333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>讨厌死了，上班上班上班不停的上班我真的超级累。什么都不干还是超级超级累。</td>\n",
       "      <td>0</td>\n",
       "      <td>0.000422</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>矮马大半夜的放肌肉男不让人睡觉了</td>\n",
       "      <td>0</td>\n",
       "      <td>0.409895</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>谢谢陈先生。</td>\n",
       "      <td>0</td>\n",
       "      <td>0.768959</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>我的2016要早点睡别熬夜</td>\n",
       "      <td>0</td>\n",
       "      <td>0.625607</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>周锐锐哥！爱你</td>\n",
       "      <td>0</td>\n",
       "      <td>0.970187</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>塞尼亚岛</td>\n",
       "      <td>0</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>只可惜没能去现场</td>\n",
       "      <td>0</td>\n",
       "      <td>0.100791</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>自从发现这个号都处于一种忍不住不看看了睡不着的状态</td>\n",
       "      <td>0</td>\n",
       "      <td>0.355194</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    text  class  positive\n",
       "index                                                                    \n",
       "0                                                  我是正面哦      0  0.347826\n",
       "1      爱是恒久忍耐，又有恩慈。爱是不嫉妒，不自夸，不张狂，不轻易发怒。不计算人的恶。凡事包容。凡事...      0  0.496333\n",
       "2                   讨厌死了，上班上班上班不停的上班我真的超级累。什么都不干还是超级超级累。      0  0.000422\n",
       "3                                       矮马大半夜的放肌肉男不让人睡觉了      0  0.409895\n",
       "4                                                 谢谢陈先生。      0  0.768959\n",
       "5                                          我的2016要早点睡别熬夜      0  0.625607\n",
       "6                                                周锐锐哥！爱你      0  0.970187\n",
       "7                                                   塞尼亚岛      0  0.500000\n",
       "8                                               只可惜没能去现场      0  0.100791\n",
       "9                              自从发现这个号都处于一种忍不住不看看了睡不着的状态      0  0.355194"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the test set; the first CSV column becomes the index.\n",
    "# Missing texts are replaced by '' so every row can be segmented as a str.\n",
    "test_csv_path = \"C:/Users/Kai/Desktop/171840708_IntroDM_MiningChallenge/mining-challenge-for-nju-introdm-2019/Mining Challenge Dataset/test.csv\"\n",
    "dfTest = pd.read_csv(test_csv_path, index_col=0)\n",
    "dfTest = dfTest.fillna({'text': ''})\n",
    "dfTest.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache C:\\Users\\Kai\\AppData\\Local\\Temp\\jieba.cache\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading model cost 0.879 seconds.\n",
      "Prefix dict has been built succesfully.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100000\n",
      "200000\n",
      "300000\n",
      "400000\n",
      "500000\n",
      "600000\n",
      "700000\n",
      "800000\n"
     ]
    }
   ],
   "source": [
    "# Word-segment every training text into a NEW array.\n",
    "# The previous loop assigned into dff['text'].values, which is the\n",
    "# DataFrame's own storage: it silently overwrote the raw text in dff, made\n",
    "# the cell non-idempotent (re-running re-segmented segmented text), and\n",
    "# fails under pandas Copy-on-Write, where .values is a read-only view.\n",
    "sen = dff['text'].map(seg_depart).values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "100000\n"
     ]
    }
   ],
   "source": [
    "# Word-segment every test text into a NEW array, leaving dfTest['text']\n",
    "# untouched (assigning into dfTest['text'].values overwrote the DataFrame's\n",
    "# own storage and fails under pandas Copy-on-Write).\n",
    "senTest = dfTest['text'].map(seg_depart).values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['我 是 正面 哦 '\n",
      " '爱是 恒久 忍耐 又 有恩慈 爱是 不嫉妒 不 自夸 不 张狂 不 轻易 发怒 不 计算 人 的 恶 凡事 包容 凡事 相信 凡事 盼望 凡事 忍耐 爱是 永不 止息 '\n",
      " '讨厌 死 了 上班 上班 上班 不停 的 上班 我 真的 超级 累 什么 都 不 干 还是 超级 超级 累 '\n",
      " '矮马 大半夜 的 放 肌肉男 不让 人 睡觉 了 ' '谢谢 陈先生 ' '我 的 2016 要 早点 睡别 熬夜 ' '周锐 锐哥 爱 你 '\n",
      " '塞 尼亚岛 ' '只 可惜 没能 去 现场 ' '自从 发现 这个 号 都 处于 一种 忍不住 不 看看 了 睡不着 的 状态 '\n",
      " '真系 咁 钟意 音乐 咩 '\n",
      " '感恩 2 续 他们 都 会 过 得 很 幸福 甜蜜 爸爸 的 身体 也 越来越 健壮 健康 妈妈 也 越来越 温柔 越 女人 我 自己 也 越来越 漂亮 皮肤 好好 非常 水润 皮肤 非常 光滑 我 弟弟 也 越来越 帅 越来越 思想 成熟 做事 非常 稳重 也 越来越 让 家人 开心 在 南昌 明年 一定 会 有 到 我 的 单身公寓 我 明年 一定 会 拿到 我 的 粉车 '\n",
      " '迷尚 自然 的 主页 ' '问叹 女王 权杖 口红 我 最 爱 的 口红 是 口红 又 是 装饰品 '\n",
      " '有个 顺序 得 先 读书 然后 才能 多 走走 否则 行再 多路 也 是 个 邮差 音乐 也 是 一样 我 倒 是 也 想 施施然 上台 去 可是 要 被 踹 下来 的 呀 预祝 巡演 成功 '\n",
      " '年终 福利 ' '声音 好好 听 '\n",
      " '少年 迪玛希 谁家 翩翩少年 郎 横空出世 迷人眼 着 调 专访 少年 迪玛希 谁家 翩翩少年 郎 横空出世 迷人眼 着 调 专访 '\n",
      " '喜欢 的 紫薯 甜品店 来 了 ' '我 不是 好惹 的 第 12 名 ' '一天 比 一天 像 公主 梦 都 被 满足 '\n",
      " '有 你 在 身边 很 心安 去 校医 室 有人 陪 去 体检 有人 陪 干什么 你 都 在 很快 又 不累 '\n",
      " '果然 全世界 女孩子 都 是 一样 的 ... 这 看 脸 的 世界 '\n",
      " '11 月 7 日 20 00 上 新 开拍 亲们 来 捧场 哦 上 新 当晚 有 给 力 优惠 还有 神秘 福袋 哦 '\n",
      " '吉林 百嘉 门将 原 国家 沙滩 足球队 主力 门将 温廷元 扑出 了 对方 王凯 的 点球 ' '午间 运动 ' '湖南 张家界 天门山 '\n",
      " '一杯 红酒 一盘 残羹剩饭 几块 蛋糕 当做 大餐 我 肯定 醉 了 ' '萌萌 哒 的 我 '\n",
      " '发现 一些 古懂 你们 以前 是 用 这种 真正 的 幻灯片 的 吗 ']\n"
     ]
    }
   ],
   "source": [
    "# Spot-check the first 30 segmented test texts (space-separated tokens,\n",
    "# stop words removed).\n",
    "print(senTest[:30])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from sklearn.model_selection import train_test_split\n",
    "# X_train, X_test, y_train, y_test = train_test_split(sen, clas, test_size=0.1, random_state=42)   ######"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# vectorizer = TfidfVectorizer(sublinear_tf=True, max_df=0.5)\n",
    "# transformer = TfidfTransformer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tf_X_train = vectorizer.fit_transform(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tf_X_test = vectorizer.transform(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tf_Test = vectorizer.transform(senTest)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from sklearn.naive_bayes import BernoulliNB, ComplementNB, MultinomialNB\n",
    "# maxxi = 0\n",
    "# maxxscore = 0\n",
    "# for i in np.arange(10, 20, 0.5):\n",
    "#     mnb = ComplementNB(alpha=i)\n",
    "#     mnb.fit(tf_X_train, y_train)\n",
    "#     print(mnb.score(tf_X_test,y_test), i)\n",
    "#     if maxxscore < mnb.score(tf_X_test,y_test):\n",
    "#         maxxscore = mnb.score(tf_X_test,y_test)\n",
    "#         maxxi = i\n",
    "\n",
    "# print(maxxscore, maxxi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mnb = ComplementNB(alpha=11.5)\n",
    "# mnb.fit(tf_X_train, y_train)\n",
    "# print(mnb.score(tf_X_test,y_test), 0.1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pred = mnb.predict(tf_Test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# csvFile = open('C:/Users/Kai/Desktop/171840708_IntroDM_MiningChallenge/mining-challenge-for-nju-introdm-2019/Mining Challenge Dataset/last0.csv','w', newline='', encoding='UTF-8') # 设置newline，否则两行之间会空一行\n",
    "# writer = csv.writer(csvFile)\n",
    "\n",
    "# writer.writerow(['ID', 'Expected'])\n",
    "# for i in range(len(pred)):\n",
    "#     if i % 50000 == 0:\n",
    "#         print(i)\n",
    "#     writer.writerow([int(i), int(pred[i])])\n",
    "    \n",
    "# csvFile.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from sklearn.svm import LinearSVC\n",
    "# model = LinearSVC(penalty='l1', dual=False, tol=1e-3)\n",
    "# model.fit(tf_X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(model.score(tf_X_test,y_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E:\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "Using TensorFlow backend.\n",
      "E:\\Anaconda3\\lib\\site-packages\\keras_preprocessing\\text.py:178: UserWarning: The `nb_words` argument in `Tokenizer` has been renamed `num_words`.\n",
      "  warnings.warn('The `nb_words` argument in `Tokenizer` '\n"
     ]
    }
   ],
   "source": [
    "# libraries\n",
    "\n",
    "import numpy as np # linear algebra\n",
    "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
    "import matplotlib.pyplot as plt\n",
    "# Seeds NumPy only; Keras/TensorFlow weight initialisation is not seeded here.\n",
    "np.random.seed(32)\n",
    "\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "from keras.preprocessing.text import Tokenizer\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "from keras.layers import LSTM, Conv1D, MaxPooling1D, Dropout\n",
    "from keras.utils.np_utils import to_categorical\n",
    "\n",
    "\n",
    "%matplotlib inline\n",
    "# Vocabulary cap: the tokenizer only emits indices for the\n",
    "# MAX_NB_WORDS - 1 most frequent training words.\n",
    "MAX_NB_WORDS = 20000\n",
    "# Maps each space-separated segmented text to a list of word indices.\n",
    "# `num_words` is the current argument name; `nb_words` is a deprecated alias\n",
    "# that produced the UserWarning recorded in this cell's output.\n",
    "tokenizer = Tokenizer(num_words=MAX_NB_WORDS, char_level=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the word index from the segmented training texts only.\n",
    "tokenizer.fit_on_texts(sen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert each training text to a list of word indices.\n",
    "sequences = tokenizer.texts_to_sequences(sen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test texts reuse the training vocabulary. No oov_token was configured, so\n",
    "# words unseen in training (or outside the vocabulary cap) are dropped.\n",
    "sequences_test = tokenizer.texts_to_sequences(senTest)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_SEQUENCE_LENGTH = 300\n",
    "\n",
    "# Bring every index sequence to exactly MAX_SEQUENCE_LENGTH entries: short\n",
    "# ones are left-padded with 0s, long ones lose their LEADING tokens. These\n",
    "# are pad_sequences' defaults, written out so the reader need not look them up.\n",
    "pad_options = {'maxlen': MAX_SEQUENCE_LENGTH, 'padding': 'pre', 'truncating': 'pre'}\n",
    "x_train = pad_sequences(sequences, **pad_options)\n",
    "x_test = pad_sequences(sequences_test, **pad_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>positive</th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>...</th>\n",
       "      <th>62</th>\n",
       "      <th>63</th>\n",
       "      <th>64</th>\n",
       "      <th>65</th>\n",
       "      <th>66</th>\n",
       "      <th>67</th>\n",
       "      <th>68</th>\n",
       "      <th>69</th>\n",
       "      <th>70</th>\n",
       "      <th>71</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>index</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>﻿ 18 年 结婚 哈哈哈</td>\n",
       "      <td>0.900696</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2017 最后 顿 大餐 吃 完 两人 世界 明年 就是 三个 人 一起 啦 许下 生日 愿...</td>\n",
       "      <td>0.999904</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>意 盎然 的 季节 祝愿 大家 都 生机勃勃 郁郁葱葱</td>\n",
       "      <td>0.736431</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2017 遇见 挚友 遇见 我 老公 结了婚 有 了 小 芒果 希望 2018 也 超级 美...</td>\n",
       "      <td>0.983905</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2018.1 1</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 74 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    text  positive  0  1  2  \\\n",
       "index                                                                         \n",
       "0                                         ﻿ 18 年 结婚 哈哈哈   0.900696  1  0  0   \n",
       "1      2017 最后 顿 大餐 吃 完 两人 世界 明年 就是 三个 人 一起 啦 许下 生日 愿...  0.999904  0  1  0   \n",
       "2                           意 盎然 的 季节 祝愿 大家 都 生机勃勃 郁郁葱葱   0.736431  0  0  1   \n",
       "3      2017 遇见 挚友 遇见 我 老公 结了婚 有 了 小 芒果 希望 2018 也 超级 美...  0.983905  0  0  0   \n",
       "4                                              2018.1 1   0.500000  0  0  0   \n",
       "\n",
       "       3  4  5  6  7 ...  62  63  64  65  66  67  68  69  70  71  \n",
       "index                ...                                          \n",
       "0      0  0  0  0  0 ...   0   0   0   0   0   0   0   0   0   0  \n",
       "1      0  0  0  0  0 ...   0   0   0   0   0   0   0   0   0   0  \n",
       "2      0  0  0  0  0 ...   0   0   0   0   0   0   0   0   0   0  \n",
       "3      1  0  0  0  0 ...   0   0   0   0   0   0   0   0   0   0  \n",
       "4      0  1  0  0  0 ...   0   0   0   0   0   0   0   0   0   0  \n",
       "\n",
       "[5 rows x 74 columns]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One-hot encode the class label: one 0/1 column per class id, appended\n",
    "# after the remaining columns. The guard keeps the cell re-runnable; the old\n",
    "# in-place drop raised KeyError on a second run because 'class' was gone.\n",
    "if 'class' in dff.columns:\n",
    "    one_hot = pd.get_dummies(dff['class'])\n",
    "    # Column j must hold class id j so argmax over the model's softmax output\n",
    "    # maps straight back to a label.\n",
    "    assert list(one_hot.columns) == list(range(one_hot.shape[1])), 'class ids are not 0..N-1'\n",
    "    dff = pd.concat([dff.drop(columns=['class']), one_hot], axis=1)\n",
    "dff.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a, b, yTrain, yTest = train_test_split(sen, y_train, test_size=0.1, random_state=42)   ######"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1 0 0 ... 0 0 0]\n",
      " [0 1 0 ... 0 0 0]\n",
      " [0 0 1 ... 0 0 0]\n",
      " ...\n",
      " [0 0 0 ... 0 0 0]\n",
      " [0 0 0 ... 0 0 0]\n",
      " [0 0 0 ... 0 0 0]]\n"
     ]
    }
   ],
   "source": [
    "# Label matrix: the one-hot class columns, i.e. every column except\n",
    "# 'text' and 'positive'.\n",
    "label_columns = dff.drop(columns=['text', 'positive'])\n",
    "y_train = label_columns.values\n",
    "print(y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 下面的cell是 分数为 0.7+ 的模型源代码\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.layers import Dense, Input, Flatten\n",
    "from keras.layers import GlobalAveragePooling1D, Embedding\n",
    "from keras.models import Model\n",
    "\n",
    "# Baseline 'bag of embeddings' classifier: embed each token, average over the\n",
    "# sequence, then a single softmax layer over the classes.\n",
    "EMBEDDING_DIM = 100\n",
    "# Derived from the label matrix instead of hard-coding 72, so the output\n",
    "# layer always matches the one-hot width of y_train.\n",
    "N_CLASSES = y_train.shape[1]\n",
    "\n",
    "# input: a sequence of MAX_SEQUENCE_LENGTH integers\n",
    "sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')\n",
    "\n",
    "# Table size equals the tokenizer's vocabulary cap (MAX_NB_WORDS), so every\n",
    "# emitted word index is a valid row.\n",
    "embedding_layer = Embedding(MAX_NB_WORDS, EMBEDDING_DIM,\n",
    "                            input_length=MAX_SEQUENCE_LENGTH,\n",
    "                            trainable=True)\n",
    "embedded_sequences = embedding_layer(sequence_input)\n",
    "\n",
    "# NOTE(review): mask_zero is not set, so the average also covers the padding\n",
    "# positions (index 0), diluting short texts. Enabling masking needs a Keras\n",
    "# version whose GlobalAveragePooling1D supports masks - verify before changing.\n",
    "average = GlobalAveragePooling1D()(embedded_sequences)\n",
    "predictions = Dense(N_CLASSES, activation='softmax')(average)\n",
    "\n",
    "model = Model(sequence_input, predictions)\n",
    "model.compile(loss='categorical_crossentropy',\n",
    "              optimizer='adam', metrics=['acc'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         (None, 300)               0         \n",
      "_________________________________________________________________\n",
      "embedding_1 (Embedding)      (None, 300, 100)          2000000   \n",
      "_________________________________________________________________\n",
      "global_average_pooling1d_1 ( (None, 100)               0         \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 72)                7272      \n",
      "=================================================================\n",
      "Total params: 2,007,272\n",
      "Trainable params: 2,007,272\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Layer shapes and parameter counts; the embedding table (MAX_NB_WORDS x\n",
    "# EMBEDDING_DIM) holds almost all of the weights.\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 776847 samples, validate on 86317 samples\n",
      "Epoch 1/2\n",
      "776847/776847 [==============================] - 464s 597us/step - loss: 3.7645 - acc: 0.1069 - val_loss: 3.7131 - val_acc: 0.1204\n",
      "Epoch 2/2\n",
      "776847/776847 [==============================] - 437s 562us/step - loss: 3.6407 - acc: 0.1374 - val_loss: 3.6173 - val_acc: 0.1411\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1ce3e392400>"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First 2 training epochs. validation_split=0.1 holds out the LAST 10% of\n",
    "# rows (Keras takes the split before shuffling), so validation rows are fixed.\n",
    "model.fit(x_train, y_train, validation_split=0.1, epochs=2, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 776847 samples, validate on 86317 samples\n",
      "Epoch 1/2\n",
      "776847/776847 [==============================] - 443s 570us/step - loss: 3.5505 - acc: 0.1555 - val_loss: 3.5596 - val_acc: 0.1546\n",
      "Epoch 2/2\n",
      "776847/776847 [==============================] - 447s 576us/step - loss: 3.4915 - acc: 0.1659 - val_loss: 3.5098 - val_acc: 0.1678\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1ce3e392d68>"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(x_train, y_train, validation_split=0.1, epochs=2, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 776847 samples, validate on 86317 samples\n",
      "Epoch 1/1\n",
      "776847/776847 [==============================] - 449s 578us/step - loss: 3.4495 - acc: 0.1732 - val_loss: 3.4994 - val_acc: 0.1659\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1ce3e392c88>"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(x_train, y_train, validation_split=0.1, epochs=1, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = pad_sequences(sequences_test, maxlen=MAX_SEQUENCE_LENGTH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred = model.predict(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = np.argmax(pred, axis = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "50000\n",
      "100000\n",
      "150000\n"
     ]
    }
   ],
   "source": [
    "# 写入文件\n",
    "csvFile = open('C:/Users/Kai/Desktop/171840708_IntroDM_MiningChallenge/mining-challenge-for-nju-introdm-2019/Mining Challenge Dataset/1.csv','w', newline='', encoding='UTF-8') # 设置newline，否则两行之间会空一行\n",
    "writer = csv.writer(csvFile)\n",
    "\n",
    "writer.writerow(['ID', 'Expected'])\n",
    "for i in range(len(result)):\n",
    "    if i % 50000 == 0:\n",
    "        print(i)\n",
    "    writer.writerow([int(i), int(result[i])])\n",
    "    \n",
    "csvFile.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save('my_model_1.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 776847 samples, validate on 86317 samples\n",
      "Epoch 1/1\n",
      "776847/776847 [==============================] - 232s 299us/step - loss: 3.4180 - acc: 0.1784 - val_loss: 3.4784 - val_acc: 0.1712\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1ce5163f940>"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(x_train, y_train, validation_split=0.1, epochs=1, batch_size=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "50000\n",
      "100000\n",
      "150000\n"
     ]
    }
   ],
   "source": [
    "pred = model.predict(res)\n",
    "result = np.argmax(pred, axis = 1)\n",
    "\n",
    "# 写入文件\n",
    "csvFile = open('C:/Users/Kai/Desktop/171840708_IntroDM_MiningChallenge/mining-challenge-for-nju-introdm-2019/Mining Challenge Dataset/2.csv','w', newline='', encoding='UTF-8') # 设置newline，否则两行之间会空一行\n",
    "writer = csv.writer(csvFile)\n",
    "\n",
    "writer.writerow(['ID', 'Expected'])\n",
    "for i in range(len(result)):\n",
    "    if i % 50000 == 0:\n",
    "        print(i)\n",
    "    writer.writerow([int(i), int(result[i])])\n",
    "    \n",
    "csvFile.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(\"2.h5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 776847 samples, validate on 86317 samples\n",
      "Epoch 1/1\n",
      "776847/776847 [==============================] - 250s 322us/step - loss: 3.3982 - acc: 0.1813 - val_loss: 3.4675 - val_acc: 0.1734\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1ce5163fef0>"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(x_train, y_train, validation_split=0.1, epochs=1, batch_size=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "50000\n",
      "100000\n",
      "150000\n"
     ]
    }
   ],
   "source": [
    "pred = model.predict(res)\n",
    "result = np.argmax(pred, axis = 1)\n",
    "\n",
    "# 写入文件\n",
    "csvFile = open('C:/Users/Kai/Desktop/171840708_IntroDM_MiningChallenge/mining-challenge-for-nju-introdm-2019/Mining Challenge Dataset/3.csv','w', newline='', encoding='UTF-8') # 设置newline，否则两行之间会空一行\n",
    "writer = csv.writer(csvFile)\n",
    "\n",
    "writer.writerow(['ID', 'Expected'])\n",
    "for i in range(len(result)):\n",
    "    if i % 50000 == 0:\n",
    "        print(i)\n",
    "    writer.writerow([int(i), int(result[i])])\n",
    "    \n",
    "csvFile.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 776847 samples, validate on 86317 samples\n",
      "Epoch 1/1\n",
      "776847/776847 [==============================] - 505s 650us/step - loss: 3.3785 - acc: 0.1845 - val_loss: 3.4588 - val_acc: 0.1735\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1ce5163ffd0>"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(x_train, y_train, validation_split=0.1, epochs=1, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 776847 samples, validate on 86317 samples\n",
      "Epoch 1/2\n",
      "776847/776847 [==============================] - 630s 811us/step - loss: 3.3545 - acc: 0.1882 - val_loss: 3.4532 - val_acc: 0.1756\n",
      "Epoch 2/2\n",
      "776847/776847 [==============================] - 606s 780us/step - loss: 3.3336 - acc: 0.1918 - val_loss: 3.4609 - val_acc: 0.1732\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1ce5163fe10>"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(x_train, y_train, validation_split=0.1, epochs=2, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 776847 samples, validate on 86317 samples\n",
      "Epoch 1/1\n",
      "776847/776847 [==============================] - 503s 647us/step - loss: 3.3144 - acc: 0.1953 - val_loss: 3.4599 - val_acc: 0.1736\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1ce5163f780>"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(x_train, y_train, validation_split=0.1, epochs=1, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 776847 samples, validate on 86317 samples\n",
      "Epoch 1/1\n",
      "776847/776847 [==============================] - 499s 642us/step - loss: 3.2965 - acc: 0.1981 - val_loss: 3.4537 - val_acc: 0.1725\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1ce517e0d30>"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(x_train, y_train, validation_split=0.1, epochs=1, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 776847 samples, validate on 86317 samples\n",
      "Epoch 1/2\n",
      "776847/776847 [==============================] - 508s 654us/step - loss: 3.2792 - acc: 0.2008 - val_loss: 3.4511 - val_acc: 0.1769\n",
      "Epoch 2/2\n",
      "776847/776847 [==============================] - 497s 640us/step - loss: 3.2625 - acc: 0.2034 - val_loss: 3.4543 - val_acc: 0.1739\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1ce5163ff60>"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(x_train, y_train, validation_split=0.1, epochs=2, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 776847 samples, validate on 86317 samples\n",
      "Epoch 1/1\n",
      "776847/776847 [==============================] - 157s 202us/step - loss: 3.2409 - acc: 0.2078 - val_loss: 3.4523 - val_acc: 0.1768\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1ce5163fac8>"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(x_train, y_train, validation_split=0.1, epochs=1, batch_size=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 776847 samples, validate on 86317 samples\n",
      "Epoch 1/1\n",
      "776847/776847 [==============================] - 179s 230us/step - loss: 3.2358 - acc: 0.2087 - val_loss: 3.4597 - val_acc: 0.1725\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1ce5163feb8>"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(x_train, y_train, validation_split=0.1, epochs=1, batch_size=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "50000\n",
      "100000\n",
      "150000\n"
     ]
    }
   ],
   "source": [
    "pred = model.predict(res)\n",
    "result = np.argmax(pred, axis = 1)\n",
    "\n",
    "# 写入文件\n",
    "csvFile = open('C:/Users/Kai/Desktop/171840708_IntroDM_MiningChallenge/mining-challenge-for-nju-introdm-2019/Mining Challenge Dataset/3.csv','w', newline='', encoding='UTF-8') # 设置newline，否则两行之间会空一行\n",
    "writer = csv.writer(csvFile)\n",
    "\n",
    "writer.writerow(['ID', 'Expected'])\n",
    "for i in range(len(result)):\n",
    "    if i % 50000 == 0:\n",
    "        print(i)\n",
    "    writer.writerow([int(i), int(result[i])])\n",
    "    \n",
    "csvFile.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
